from pyspark.ml.regression import LinearRegression, LinearRegressionModel
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
from utils.spark_util import col_validation, one_hot_encoding, feature_cols_labelling, label_col_labelling


def train(sc, spark, dataframe, para_list, record, output):
    """Fit a linear regression model with a train/validation hyperparameter search.

    Args:
        sc: SparkContext, passed through to the one-hot-encoding helper.
        spark: SparkSession (unused here; kept for interface compatibility).
        dataframe: input Spark DataFrame containing the feature columns, the
            label column named by ``para_list["labelCol"]``, and optionally a
            bookkeeping ``'_record_id_'`` column.
        para_list: dict of run parameters; must contain ``"labelCol"`` and
            ``"source"``.
        record: unused; kept for interface compatibility.
        output: dict with a ``"message"`` list; encoding warnings are appended.

    Returns:
        The fitted ``TrainValidationSplitModel`` (best model by RMSE).
    """
    # --------------------prepare data------------------
    feature_cols = dataframe.columns
    feature_cols.remove(para_list["labelCol"])
    # '_record_id_' is bookkeeping, not a feature; it may be absent, so guard
    # the removal instead of letting list.remove raise ValueError.
    if '_record_id_' in feature_cols:
        feature_cols.remove('_record_id_')
    print(feature_cols)
    num_cols, cat_cols = col_validation(dataframe, feature_cols)

    # Categorical columns must be one-hot encoded before feature assembly.
    if cat_cols:
        dataframe, message, encoded_cat_cols = one_hot_encoding(sc, dataframe, para_list["source"], cat_cols)
        if message:
            output["message"].append(message)
    else:
        encoded_cat_cols = []

    feature_list = num_cols + encoded_cat_cols
    print(feature_list)
    df = feature_cols_labelling(dataframe, feature_list)
    df.show(1)
    df = label_col_labelling(df, para_list["labelCol"])
    df.show(2)

    # ---------------------model fit---------------------
    lr = LinearRegression(maxIter=10)
    # Search over regularization strength, intercept fitting, and L1/L2 mix.
    paramGrid = ParamGridBuilder() \
        .addGrid(lr.regParam, [0.1, 0.01]) \
        .addGrid(lr.fitIntercept, [False, True]) \
        .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0]) \
        .build()

    # RegressionEvaluator defaults to RMSE; 80% of df trains each candidate,
    # the remaining 20% scores it to pick the best parameter map.
    tvs = TrainValidationSplit(estimator=lr,
                               estimatorParamMaps=paramGrid,
                               evaluator=RegressionEvaluator(),
                               trainRatio=0.8)

    # Run TrainValidationSplit, and choose the best set of parameters.
    model = tvs.fit(df)
    print('model fitted')

    # ---------------------get metrics--------------------
    # BUG FIX: `model.estimatorParamMaps` is the Param *descriptor*, so printing
    # it shows param metadata rather than the searched maps. Use the getter,
    # and also print the validation metric for each map (same order as grid).
    print(model.getEstimatorParamMaps())
    print(model.validationMetrics)
    model.transform(df).show()

    # TODO: feature selection not yet implemented.
    return model